Add fp32 LM head, docs_per_step accumulation, and document-count metrics#520
Open
jlamypoirier wants to merge 8 commits into
Open
Add fp32 LM head, docs_per_step accumulation, and document-count metrics#520jlamypoirier wants to merge 8 commits into
jlamypoirier wants to merge 8 commits into
Conversation
When True, upcasts the LM head linear's input and weight to FP32 before the matmul, matching vLLM's bf16_last_layer_fp32 quantization. This lets the trainer compute log-probabilities at the same numerical precision as the actor's sampling, so the importance-sampling ratio starts near 1.0 instead of being inflated by trainer/actor precision mismatch. The detached FP32 weight has requires_grad=False, which makes output_parallel_linear_backward skip the weight-grad path. The FSDP gradient contract is restored by computing grad_weight explicitly and accumulating into the original BF16 param's grad_buffer via accumulate_gradient. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
A schedule config field that replaces the static microbatch count with a
runtime document-count target. Matches DeepSpeed's
gradient_accumulation_passes semantics for RL: each microbatch holds one
rollout and the step boundary is set by total rollouts rather than a
fixed microbatch count.
- ScheduleConfig.docs_per_step — when >0, Trainer._prefetch_to_doc_target
fetches microbatches one at a time, all-reduces the per-microbatch doc
count, and stops once the global total reaches the target. The final
step total is broadcast to every microbatch so the loss normalization
stays consistent.
- Trainer._get_or_build_schedule(N) builds and caches a per-N Schedule
with _depth_first_override = N // breadth_first_micro_batches, reusing
the schedule machinery without touching the runner.
- Schedule._eff_{depth_first,sequential_micro_batches,num_inputs} expose
the effective values under an override.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
debc3dd to
5afa8c7
Compare
Surface the per-step document count produced by `_prefetch_to_doc_target` (the loss-normalization denominator) and the cumulative document total as training metrics. Lets the dynamic `docs_per_step` accumulation be verified in production and gives documents-seen as a cross-run x-axis. Gated on `docs_per_step > 0`; no effect on the static-schedule path. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 task
Padding tokens fall past the last real document, so searchsorted assigned them a phantom (num_documents+1)-th index, one past the per-segment buffer sized by num_documents_in_sequence -> CUDA device-side assert in the GSPO index_add_. Clamp the 1-based document index onto the last real document; padding targets are masked so the contribution is zero. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Training-side changes for RL fine-tuning, consolidated into a single PR to
main(previously stacked PRs #526 + #520; #526 is closed in favor of this one).
fp32_lm_head— FP32 LM head logitsNew
fp32_lm_headflag onLanguageModelHeadConfig(defaultFalse). When enabled, the LM headupcasts input and weight to FP32 for the logits projection and casts back, matching vLLM's
bf16_last_layer_fp32, so the trainer computes log-probabilities at the same precision the actorsampled with. Includes the gradient-flow fix for the detached FP32 weight copy — gradients are
accumulated back into the BF16 parameter's buffer.
Dynamic
docs_per_stepaccumulationNew
ScheduleConfig.docs_per_stepfield. When>0, each step accumulates microbatches one at atime, all-reduces the per-microbatch document count, and stops once the global total reaches the
target, instead of using a fixed microbatch count. The final step total is broadcast to every
microbatch so the loss-normalization denominator stays consistent. Off by default
(
docs_per_step=0keeps the original static-schedule path).num_documents/documents_seenmetricsLogs the per-step document count (the divisor
docs_per_stepproduces) and the cumulative documenttotal as training metrics — lets the dynamic accumulation be verified, and gives documents-seen as a
cross-run x-axis. Gated on
docs_per_step>0; no effect on the static path.GSPO segment-index fix for padded sequences
Clamp
global_document_index_qtonum_documents_in_sequenceinfast_llm/data/document/token.py.Padding tokens fall past the last real document, so
searchsortedassigned them a phantomout-of-range segment index, causing a CUDA device-side assert in the GSPO
index_add_. Paddingtargets are masked, so clamping them onto the last real document contributes zero.
Test plan
pytest tests/layers/test_docs_per_step.pySplit from #502.